name: Install env + build
description: >-
  Create (or reuse a cached) conda environment with PyTorch and CUDA,
  then build the project in editable mode using ccache.
inputs:
  # Exported as TORCH_CUDA_ARCH_LIST for the build step.
  arch:
    description: 'GPU architecture'
    required: true
  # Pinned as `python=<value>` in the conda create command.
  python:
    description: 'Python version'
    required: false
    default: "3.11"
  # Pinned as `pytorch-cuda=<value>` in the conda create command.
  cuda:
    description: 'CUDA version'
    required: false
    default: "11.8"
  # Pinned as `pytorch=<value>` in the conda create command.
  pytorch_version:
    description: 'PyTorch version'
    required: false
    default: "2"
  # Passed as the first `-c` channel. Any value other than "pytorch" also adds
  # the current year/week to the environment cache key.
  pytorch_channel:
    description: 'PyTorch channel on conda'
    required: false
    default: "pytorch"

runs:
  using: composite
  steps:
    # Remove stale shell startup files; the "Activate environment" step later
    # appends conda's init script to ~/.profile.
    - name: Cleanup
      shell: bash
      run: |
        rm -f ~/.bashrc ~/.profile
    - id: prepare_conda_env_paths
      shell: python
      run: |
        # Decide where the conda environment for this run lives and whether a
        # previously built one can be reused, then hand the result to later
        # steps: paths and commands via GITHUB_ENV, ENV_CACHED via GITHUB_OUTPUT.
        import datetime
        import glob
        import hashlib
        import os
        from pathlib import Path

        CONDA_INSTALL_CMD = "conda create python=${{ inputs.python }} zlib pip ninja pytorch=${{ inputs.pytorch_version }} torchvision ccache=4.8 pytorch-cuda=${{ inputs.cuda }} -c ${{ inputs.pytorch_channel }} -c nvidia -c conda-forge -q -y"

        # Cache key: the install command plus the content of every
        # requirement*.txt file found in the working directory.
        conda_env_key = CONDA_INSTALL_CMD
        for file in sorted(glob.glob("requirement*.txt")):
          conda_env_key += f"\n########## {file}\n"
          conda_env_key += Path(file).read_text(encoding="utf-8")
        if "${{ inputs.pytorch_channel }}" != "pytorch":
          # Nightly or Test, update every week
          conda_env_key += datetime.date.today().strftime("%Y-week%W") + "\n"
        # UTF-8 yields the same bytes (and hash) as ASCII for ASCII-only keys,
        # but does not raise when a requirements file contains other characters.
        conda_env_hash = hashlib.sha224(conda_env_key.encode("utf-8")).hexdigest()[:8]

        shared_dir = os.environ.get("GHRUNNER_SHARED_DIR", os.getcwd())
        # A fresh environment would be created under tmp/<run id>.
        env_path = os.path.join(shared_dir, "tmp", os.environ["GITHUB_RUN_ID"])
        # env_<hash>.txt holds the path of an already built environment.
        final_env = Path(shared_dir) / f"env_{conda_env_hash}.txt"
        # Keep the full key next to it so a hash can be traced back to its inputs.
        (Path(shared_dir) / f"env_{conda_env_hash}_content.txt").write_text(conda_env_key, encoding="utf-8")
        CONDA_INSTALL_CMD += " -p " + env_path

        env_already_built = False
        # If environment is already built
        if final_env.is_file():
          final_env_link = final_env.read_text().strip()
          if (Path(final_env_link) / "bin" / "python").is_file():
            print("Found valid env - skipping env setup")
            # Turn the install command into a no-op and point at the cached env.
            CONDA_INSTALL_CMD = "true"
            env_already_built = True
            env_path = final_env_link
          else:
            print("Invalid env")

        # Append ("a"): these files may already hold entries, and "r+" would
        # overwrite them starting at offset 0.
        with open(os.environ["GITHUB_ENV"], "a") as fp:
          fp.write("CONDA_ENV_LINK=" + str(final_env) + "\n")
          fp.write("CONDA_PREFIX=" + env_path + "\n")
          fp.write("CONDA_INSTALL_CMD=" + CONDA_INSTALL_CMD + "\n")
          fp.write("CONDA_ENV_HASH=" + conda_env_hash + "\n")
          fp.write("PY=" + os.path.join(env_path, "bin", "python") + "\n")
          fp.write("PIP=" + os.path.join(env_path, "bin", "pip") + "\n")
        with open(os.environ["GITHUB_OUTPUT"], "a") as fp:
          fp.write(f"ENV_CACHED={int(env_already_built)}\n")
    - name: Print conda commands
      shell: bash -l {0}
      run: |
        # Show the values exported by the prepare_conda_env_paths step.
        for var in CONDA_PREFIX CONDA_INSTALL_CMD CONDA_ENV_HASH PY; do
          echo "$var=${!var}"
        done
    # Only runs when no cached environment was found.
    - name: Conda/pip setup
      if: steps.prepare_conda_env_paths.outputs.ENV_CACHED == 0
      shell: bash -l {0}
      run: |
        set -ex
        conda config --set channel_priority strict
        # If the first attempt fails, drop the downloaded packages cache and
        # the partially created environment, then try exactly once more.
        if ! $CONDA_INSTALL_CMD; then
          rm -rf $HOME/.conda/pkgs
          rm -rf $CONDA_PREFIX
          $CONDA_INSTALL_CMD
        fi
        $PY -m pip install cmake
        $PY -m pip install -r requirements-benchmark.txt --progress-bar off
    # Arrange for the conda environment to be activated in later steps.
    # The script resets ~/.bashrc, runs `conda init`, appends a
    # `conda activate $CONDA_PREFIX` line, then appends the resulting ~/.bashrc
    # to ~/.profile, deletes ~/.bashrc and prints ~/.profile for the log.
    # NOTE(review): the steps in this action use `bash -l` (a login shell);
    # presumably that is why only ~/.profile is picked up - confirm.
    - name: Activate environment
      shell: bash -l {0}
      run: |
        echo "" > ~/.bashrc
        conda init
        echo "conda activate $CONDA_PREFIX" >> ~/.bashrc
        # For some reason bash only reads from `.profile`
        # but conda writes init script to `.bashrc`
        cat ~/.bashrc >> ~/.profile
        rm ~/.bashrc
        echo "==== .profile ====="
        cat ~/.profile
    # Log which python a fresh login shell resolves.
    - shell: bash -l {0}
      run: which python
    # Only runs when the environment was freshly created.
    - name: Setup ccache nvcc
      if: steps.prepare_conda_env_paths.outputs.ENV_CACHED == 0
      shell: bash -l {0}
      run: |
        # Generate a wrapper script that runs nvcc through ccache.
        # ${CUDA_HOME} is expanded now; "$@" is written literally so the
        # wrapper forwards its own arguments.
        wrapper=$CONDA_PREFIX/bin/nvcc-ccache
        {
          echo "#!/bin/bash"
          echo "ccache ${CUDA_HOME}/bin/nvcc \"\$@\""
        } > $wrapper
        cat $wrapper
        chmod +x $wrapper
        which nvcc
        ccache --version
    # Only runs when the environment was freshly created.
    - name: Setup ccache g++
      if: steps.prepare_conda_env_paths.outputs.ENV_CACHED == 0
      shell: bash -l {0}
      run: |
        # Generate a wrapper script that runs g++ through ccache; "$@" is
        # written literally so the wrapper forwards its own arguments.
        wrapper=$CONDA_PREFIX/bin/g++-ccache
        {
          echo "#!/bin/bash"
          echo "ccache g++ \"\$@\""
        } > $wrapper
        cat $wrapper
        chmod +x $wrapper
        which g++-ccache
    - name: Patch for https://github.com/pytorch/pytorch/issues/114962
      shell: bash -l {0}
      run: |
        # Find the installed torch.utils.cpp_extension module and delete every
        # line in it that mentions generate-dependencies-with-compile.
        ext_file=$(python -c 'import torch.utils.cpp_extension as ext; print(ext.__file__)')
        echo "Patching $ext_file"
        sed -i '/generate-dependencies-with-compile/d' $ext_file
    # Diagnostic step: print the libcuda.so entries known to the dynamic
    # linker cache, then list the contents of /.singularity.d/libs/.
    # NOTE(review): the second path suggests the job runs inside a
    # Singularity/Apptainer container - confirm against the runner setup.
    - name: Check NVIDIA libs
      shell: bash -l {0}
      run: |
        ldconfig -p | grep libcuda.so
        ls /.singularity.d/libs/
    # Record the environment path in the env_<hash>.txt file that the
    # prepare_conda_env_paths step reads to detect an already built env.
    - name: Mark env as ready
      if: steps.prepare_conda_env_paths.outputs.ENV_CACHED == 0
      shell: bash -l {0}
      run: |
        echo $CONDA_PREFIX > $CONDA_ENV_LINK
    # Point ccache at a directory under the shared runner dir and export
    # CCACHE_DIR to the following steps.
    - name: Setup ccache
      shell: bash -l {0}
      run: |
        # Fall back to the working directory when GHRUNNER_SHARED_DIR is unset
        # or empty, as the prepare_conda_env_paths step does; without the
        # fallback the cache would be placed at /ccache.
        export CCACHE_DIR="${GHRUNNER_SHARED_DIR:-$PWD}/ccache"
        echo "CCACHE_DIR=$CCACHE_DIR" >> "${GITHUB_ENV}"
        mkdir -p "$CCACHE_DIR"
        # Print cache statistics for the log.
        ccache -s
    # Editable install of the project, compiling through the ccache wrappers.
    - name: Build
      shell: bash -l {0}
      env:
        # Pass the input through the environment instead of splicing it into
        # the script text: a value containing spaces (several architectures)
        # or shell metacharacters then stays a single, inert string.
        INPUT_ARCH: ${{ inputs.arch }}
      run: |
        PYTORCH_NVCC="$CONDA_PREFIX/bin/nvcc-ccache" CXX="g++-ccache" TORCH_CUDA_ARCH_LIST="$INPUT_ARCH" python -m pip install -v -e .
    # Post-build diagnostics.
    - name: Build info
      shell: bash -l {0}
      run: |
        # Full environment of the login shell.
        printenv
        # Report from the freshly built package.
        python -m xformers.info
        python xformers/_triton_version_fairinternal.py
        # ccache statistics after the build.
        ccache -s
